{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UfYIQNU1WH_H"
   },
   "source": [
    "## Set current directory (move up to the parent folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3759,
     "status": "ok",
     "timestamp": 1607430474795,
     "user": {
      "displayName": "Apurva Bhargava",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Ggtud541LE-7b_PbBmTGtGkNn9nRFwEQ3keJswI6Q=s64",
      "userId": "07288249218888651888"
     },
     "user_tz": 300
    },
    "id": "JFf-AvwE0llf",
    "outputId": "c94092fb-89a4-43ec-dbb4-aeec77bd38b0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E:\\Dropbox\\Optimal Training Sets\\Replication File v4\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Move to the parent of the starting working directory, but resolve that\n",
    "# target only once per kernel session: re-running this cell must not climb\n",
    "# one directory higher on every execution.\n",
    "if 'PROJECT_ROOT' not in globals():\n",
    "    PROJECT_ROOT = os.path.abspath('..')\n",
    "os.chdir(PROJECT_ROOT)\n",
    "print(os.getcwd())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4wsQNWAFxxhq"
   },
   "source": [
    "## Random Pick"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "vEZBWNQIxxAu"
   },
   "outputs": [],
   "source": [
    "from random import choices\n",
    "\n",
    "def random_indices(max_obs, num_obs):\n",
    "  return choices(range(max_obs), k=num_obs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UI3LTFbHt8J6"
   },
   "source": [
    "## K-means Clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "sZVSiqg53Wxb"
   },
   "outputs": [],
   "source": [
    "from fast_pytorch_kmeans import KMeans\n",
    "import torch\n",
    "from random import choices\n",
    "\n",
    "def kmeans_indices(obs, num_obs):\n",
    "  kmeans = KMeans(n_clusters=num_obs, mode='euclidean', verbose=1)\n",
    "  labels = kmeans.fit_predict(obs)\n",
    "  label_idx_dict = {}\n",
    "  for index, label in enumerate(labels):\n",
    "    label = label.item()\n",
    "    if label in label_idx_dict:\n",
    "      label_idx_dict[label].append(index)\n",
    "    else:\n",
    "      label_idx_dict[label] = [index]\n",
    "  indices = [choices(label_idx_dict[key], k=1)[0] for key in label_idx_dict]\n",
    "  if len(indices) < num_obs:\n",
    "    more_indices = [choices(label_idx_dict[key], k=1)[0] for key in label_idx_dict]\n",
    "    for idx in more_indices:\n",
    "      if idx not in indices:\n",
    "        indices.append(idx)\n",
    "      if len(indices) == num_obs:\n",
    "        break\n",
    "  return indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zWhZ-1MK0LZR"
   },
   "source": [
    "## Farthest Point Sampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "t9Skr5k9wscx"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def farthestPointSampler(dist_matrix, num_obs):\n",
    "  indices = np.zeros(num_obs, dtype=np.int64)\n",
    "  # select two farthest points\n",
    "  indices[0], indices[1] = np.unravel_index(dist_matrix.argmax(), dist_matrix.shape)\n",
    "  for i in range(2, num_obs):\n",
    "    # maximize minimum distance to all points in indices\n",
    "    sorted_indices = np.argsort(np.min(dist_matrix[indices[:i],:], axis=0))[::-1]\n",
    "    #sorted_indices = np.setdiff1d(sorted_indices, indices[:i])\n",
    "    sorted_indices = sorted_indices[~np.in1d(sorted_indices, indices[:i])]\n",
    "    indices[i] = sorted_indices[0]\n",
    "  return indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uzXnP06r8WnM"
   },
   "source": [
    "## Greedy farthest point based on KL Divergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "N_1KLi4lUrKr"
   },
   "outputs": [],
   "source": [
    "# Basis: the values of each embedding vector are modeled as draws from a normal distribution (fitted per row by fit_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "Mpfj6uey2o5y"
   },
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "nHKkgvGA2jUW"
   },
   "outputs": [],
   "source": [
    "from scipy.stats import norm\n",
    "from itertools import combinations_with_replacement\n",
    "from itertools import chain\n",
    "import numpy as np\n",
    "\n",
    "def fit_norm(obs):\n",
    "  mu_list = []\n",
    "  sd_list = []\n",
    "  n_obs = obs.shape[0]\n",
    "  for i in range(n_obs):\n",
    "    mu, sd = norm.fit(obs[i,:])\n",
    "    mu_list.append(mu)\n",
    "    sd_list.append(sd)\n",
    "  return mu_list, sd_list\n",
    "\n",
    "def gaussian_kld(mu1, sd1, mu2, sd2):\n",
    "  return np.log(sd2/sd1) + ((sd1**2 + (mu1-mu2)**2) / (2*(sd2**2))) - 0.5\n",
    "\n",
    "def get_kld_matrix(mu_list, sd_list, dataset_name=\"\", embed_type=\"\"):\n",
    "  dshape = len(mu_list)\n",
    "  kld_matrix = np.zeros((dshape, dshape))\n",
    "  looper = combinations_with_replacement(range(dshape), 2)\n",
    "  for i, j in looper:\n",
    "    kld_ij = gaussian_kld(mu_list[i], sd_list[i], mu_list[j], sd_list[j]) + gaussian_kld(mu_list[j], sd_list[j], mu_list[i], sd_list[i])\n",
    "    kld_matrix[i][j] = kld_ij\n",
    "    kld_matrix[j][i] = kld_ij\n",
    "  print('Saving kld matrix...')\n",
    "  np.save(\"data/output/\" +dataset_name+'_kld_'+embed_type, kld_matrix)\n",
    "  #return kld_matrix\n",
    "  return dataset_name+'_kld_'+embed_type+'.npy'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dslNYRxovruu"
   },
   "source": [
    "## Greedy Farthest Point Sampler using Kolmogorov-Smirnov measure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "--ZV29q5knEA"
   },
   "outputs": [],
   "source": [
    "def ks_2samp_faster(data1, data2):\n",
    "    data_all = np.concatenate([data1, data2])\n",
    "    # using searchsorted solves equal data problem\n",
    "    cdf1 = np.searchsorted(data1, data_all, side='right') / data1.shape[0]\n",
    "    cdf2 = np.searchsorted(data2, data_all, side='right') / data2.shape[0]\n",
    "    cddiffs = cdf1 - cdf2\n",
    "    minS = np.clip(-np.min(cddiffs), 0, 1)  # Ensure sign of minS is not negative.\n",
    "    maxS = np.max(cddiffs)\n",
    "    d = max(minS, maxS)\n",
    "    return d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "YSidcg2DvqUB"
   },
   "outputs": [],
   "source": [
    "def get_ks_matrix(obs, dataset_name=\"\", embed_type=\"\"):\n",
    "  \"\"\"Build and save the pairwise Kolmogorov-Smirnov matrix of the rows of `obs`.\n",
    "\n",
    "  Every row is sorted once up front (as `ks_2samp_faster` requires), then\n",
    "  the statistic is computed for each pair of rows and mirrored across the\n",
    "  diagonal.\n",
    "\n",
    "  Returns:\n",
    "    The name of the saved file, `<dataset_name>_ks_<embed_type>.npy`; the\n",
    "    file itself is written below 'data/output/'.\n",
    "  \"\"\"\n",
    "  n_rows = len(obs)\n",
    "  sorted_rows = np.sort(obs, axis=1)\n",
    "  ks_matrix = np.zeros((n_rows, n_rows))\n",
    "  for row in range(n_rows):\n",
    "    for col in range(row, n_rows):\n",
    "      # the statistic is symmetric, so both triangles get the same value\n",
    "      ks_matrix[row, col] = ks_matrix[col, row] = ks_2samp_faster(sorted_rows[row], sorted_rows[col])\n",
    "  print('Saving ks matrix...')\n",
    "  out_name = dataset_name+'_ks_'+embed_type\n",
    "  np.save(\"data/output/\" + out_name, ks_matrix)\n",
    "  return out_name+'.npy'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KvSO_DCStI14"
   },
   "source": [
    "## Greedy Farthest Point Sampler using Cosine Distance Matrix\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "bFKa_oZhtH_P"
   },
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import cosine\n",
    "def get_cos_matrix(obs, dataset_name=\"\", embed_type=\"\"):\n",
    "  num_obs = len(obs)\n",
    "  cos_matrix = np.zeros((num_obs, num_obs))\n",
    "  for i in range(num_obs):\n",
    "    for j in range(i, num_obs):\n",
    "      val = cosine(obs[i], obs[j])\n",
    "      cos_matrix[i][j] = val\n",
    "      cos_matrix[j][i] = val\n",
    "  print('Saving cosine matrix...')\n",
    "  np.save(\"data/output/\" +dataset_name+'_cos_'+embed_type, cos_matrix)\n",
    "  return dataset_name+'_cos_'+embed_type+'.npy'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-XDmWo6c3dBZ"
   },
   "source": [
    "## D-Optimality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def variancedist(W, S):\n",
    "    try:\n",
    "        R = np.linalg.inv(S.T @ S)\n",
    "    except:\n",
    "        R = np.linalg.inv(S.T @ S + np.eye(S.shape[1]) * 0.001)\n",
    "    D = torch.sum( (W @ R) * W, dim = 1, dtype = torch.float)\n",
    "    return D\n",
    "\n",
    "def dopt(topicmat, k):\n",
    "    index = [np.random.choice(topicmat.shape[0])]\n",
    "    rows = np.array(range(topicmat.shape[0]))\n",
    "    rows = np.delete(rows, index)\n",
    "    S = topicmat[index,:]\n",
    "    W = topicmat[~np.isin(range(topicmat.shape[0]), index)]\n",
    "    # print(S.shape, W.shape)\n",
    "    while len(index) < k:\n",
    "        i = np.argmax(variancedist(W, S))\n",
    "        S = np.vstack((S, W[i,:]))\n",
    "        W = np.delete(W, i, axis=0)\n",
    "        index.append(rows[i])\n",
    "        rows = np.delete(rows, i)\n",
    "    return index\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyOAVDfWph9ykt0naW7FbdAA",
   "collapsed_sections": [],
   "name": "SelectIndices.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
